# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from multiprocessing import Manager
from typing import Any
from typing import Callable
from typing import Dict
from typing import List

import numpy as np
from paddle.io import Dataset


class DataTable(Dataset):
    """Dataset to load and convert data for general purpose.

    Args:
        data (List[Dict[str, Any]]): Metadata, a list of meta datum, each of which is composed of several fields
        fields (List[str], optional): Fields to use, if not specified, all the fields in the data are used, by default None
        converters (Dict[str, Callable], optional): Converters used to process each field, by default None
        use_cache (bool, optional): Whether to use cache, by default False

    Raises:
        ValueError:
            If there is some field that does not exist in data.
        ValueError:
            If there is some field in converters that does not exist in fields.
    """

    def __init__(self,
                 data: List[Dict[str, Any]],
                 fields: List[str]=None,
                 converters: Dict[str, Callable]=None,
                 use_cache: bool=False):
        # metadata
        self.data = data
        assert len(data) > 0, "This dataset has no examples"

        # Peek at the first example to get the existing fields. Only the first
        # example is inspected; presumably every example shares its keys.
        first_example = self.data[0]
        fields_in_data = first_example.keys()

        # check all the requested fields exist
        if fields is None:
            # Snapshot as a list so later mutation of the first example does
            # not silently change which fields are used.
            self.fields = list(fields_in_data)
        else:
            for field in fields:
                if field not in fields_in_data:
                    raise ValueError(
                        f"The requested field ({field}) is not found "
                        f"in the data. Fields in the data are "
                        f"{list(fields_in_data)}")
            self.fields = fields

        # check converters: each converter must target a used field
        if converters is None:
            self.converters = {}
        else:
            for field in converters.keys():
                if field not in self.fields:
                    raise ValueError(
                        f"The converter has a non existing field ({field})")
            self.converters = converters

        self.use_cache = use_cache
        if use_cache:
            self._initialize_cache()

    def _initialize_cache(self):
        """Create a process-shared list with one empty (None) slot per example."""
        # A Manager-backed list lets cached examples be shared across processes
        # (e.g. dataloader workers); build it pre-filled in a single call.
        self.manager = Manager()
        self.caches = self.manager.list([None] * len(self))

    def _get_metadata(self, idx: int) -> Dict[str, Any]:
        """Return a meta-datum given an index."""
        return self.data[idx]

    def _convert(self, meta_datum: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a meta datum to an example by applying the corresponding
        converters to each field requested.

        Fields without a converter are passed through unchanged.

        Args:
            meta_datum (Dict[str, Any]): Meta datum

        Returns:
            Dict[str, Any]: Converted example
        """
        example = {}
        for field in self.fields:
            converter = self.converters.get(field, None)
            meta_datum_field = meta_datum[field]
            if converter is not None:
                converted_field = converter(meta_datum_field)
            else:
                converted_field = meta_datum_field
            example[field] = converted_field
        return example

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get an example given an index.

        When caching is enabled, a previously converted example is returned
        directly; otherwise the example is converted and (if caching) stored.

        Args:
            idx (int): Index of the example to get

        Returns:
            Dict[str, Any]: A converted example
        """
        if self.use_cache:
            # Read the proxy once: each indexing of a Manager list is an IPC
            # round trip.
            cached = self.caches[idx]
            if cached is not None:
                return cached

        meta_datum = self._get_metadata(idx)
        example = self._convert(meta_datum)

        if self.use_cache:
            self.caches[idx] = example

        return example

    def __len__(self) -> int:
        """Return the size of the dataset.

        Returns:
            int: The length of the dataset
        """
        return len(self.data)


class StarGANv2VCDataTable(DataTable):
    """DataTable producing source / reference pairs for StarGANv2-VC.

    Each example pairs a source utterance with two reference utterances:
    one drawn uniformly at random from the whole dataset, and a second one
    drawn at random from the same speaker as the first reference.

    Args:
        data (List[Dict[str, Any]]): Metadata, a list of meta datum. Each
            meta datum is read via the keys ``utt_id``, ``spk_id`` (must be
            convertible with ``int``) and ``speech`` (a path loadable by
            ``np.load``), e.g.
            ``{'utt_id': 'p225_111', 'spk_id': '1', 'speech': 'path of *.npy'}``.
    """

    def __init__(self, data: List[Dict[str, Any]]):
        super().__init__(data)
        # Group meta data by speaker in a single pass (the old nested loop was
        # O(n * k) for n items and k speakers). Each per-speaker list keeps
        # the original data order and is guaranteed non-empty.
        data_list_per_class: Dict[Any, List[Dict[str, Any]]] = {}
        for item in self.data:
            data_list_per_class.setdefault(item['spk_id'], []).append(item)
        self.data_list_per_class = data_list_per_class

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get an example given an index.

        Args:
            idx (int): Index of the example to get

        Returns:
            Dict[str, Any]: A dict with keys ``utt_id``, ``mel`` (array loaded
            from the source ``speech`` path), ``label`` (source speaker id as
            int), ``ref_mel`` / ``ref_mel_2`` (arrays loaded for the two
            references) and ``ref_label`` (reference speaker id as int).
        """
        # NOTE: this constructor does not enable use_cache; if it ever is,
        # the random reference choices below would be frozen per index.
        if self.use_cache:
            # Read the proxy once: each indexing is an IPC round trip.
            cached = self.caches[idx]
            if cached is not None:
                return cached

        data = self._get_metadata(idx)

        # Cropping is done later in batch_fn; here the full arrays are loaded
        # and returned as a dict.
        # Reference 1: a random utterance from the whole dataset.
        ref_data = random.choice(self.data)
        ref_label = ref_data['spk_id']
        # Reference 2: a random utterance by the same speaker as reference 1.
        ref_data_2 = random.choice(self.data_list_per_class[ref_label])
        # mel_tensor, label, ref_mel_tensor, ref2_mel_tensor, ref_label
        new_example = {
            'utt_id': data['utt_id'],
            'mel': np.load(data['speech']),
            'label': int(data['spk_id']),
            'ref_mel': np.load(ref_data['speech']),
            'ref_mel_2': np.load(ref_data_2['speech']),
            'ref_label': int(ref_label)
        }

        if self.use_cache:
            self.caches[idx] = new_example

        return new_example
